//
//  NeuralNetworkShapes.cpp
//  mlmodel
//
//  Created by William March on 12/5/17.
//  Copyright © 2017 Apple Inc. All rights reserved.
//

#include "NeuralNetworkShapes.hpp"
#include "Utils.hpp"

using namespace CoreML;

// Propagates shape constraints through a convolution / deconvolution layer.
//
// Looks up the constraints of the layer's first input and first output blob in
// `blobShapes` and tightens both of them:
//   - sequence and batch ranges are exchanged in both directions,
//   - the input channel count is pinned from the kernel parameters,
//   - the output channel count is pinned to `outputchannels`,
//   - height/width ranges are related through kernel size, stride, dilation
//     and padding.
//
// Throws std::runtime_error when the padding type is not set. The padding type
// is not inspected at all for a deconvolution that carries an explicit
// two-element output shape.
void NeuralNetworkShaper::shapeConvolutionLayer(const Specification::NeuralNetworkLayer& specLayer) {

    // one input, one output
    ShapeConstraint& inputShape = blobShapes[specLayer.input(0)];

    // Check if this already exists? Can't -- outputs are only generated by one layer
    ShapeConstraint& outputShape = blobShapes[specLayer.output(0)];
    outputShape.setName(specLayer.output(0));

#if COREML_VALIDATOR_VERBOSE
    std::cout << std::endl << "Convolution layer " << specLayer.name() << " input shapes (before): " << std::endl;
    std::cout << inputShape;
    std::cout << "Convolution layer " << specLayer.name() << " output shapes (before): " << std::endl;
    std::cout << outputShape;
#endif

    // Sequence and batch pass through unchanged, so exchange them in both directions.
    // This can't grow the ranges if they were already set
    outputShape.updateSequenceRange(inputShape.sequenceRange());
    outputShape.updateBatchRange(inputShape.batchRange());

    inputShape.updateSequenceRange(outputShape.sequenceRange());
    inputShape.updateBatchRange(outputShape.batchRange());

    const Specification::ConvolutionLayerParams& conv = specLayer.convolution();

    //validate input shape
    // An ngroups value of 0 is treated as a single group.
    int n_groups = (conv.ngroups() == 0) ? 1 : static_cast<int>(conv.ngroups());
    int K = 0;
    if (conv.isdeconvolution()) {
        K = static_cast<int>(conv.kernelchannels()); //number of input channels: must match channels of previous data blob
    } else {
        K = static_cast<int>(conv.kernelchannels()) * n_groups;  //number of input channels: must match channels of previous data blob
    }

    // There is no flexibility in the number of input channels
    inputShape.setChannel(static_cast<size_t>(K));

    //compute top shape
    // Defaults used when the repeated fields do not hold exactly two entries:
    // 3x3 kernel, stride 1, dilation 1.
    int Kh, Kw, hstride, wstride, hdilation, wdilation;
    Kh = Kw = 3;
    hstride = wstride = hdilation = wdilation = 1;

    if (conv.kernelsize_size() == 2){
        Kh = static_cast<int>(conv.kernelsize(0)); //height
        Kw = static_cast<int>(conv.kernelsize(1)); //width
    }
    if (conv.stride_size() == 2){
        hstride = static_cast<int>(conv.stride(0)); //height
        wstride = static_cast<int>(conv.stride(1)); //width
    }
    if (conv.dilationfactor_size() == 2){
        hdilation = static_cast<int>(conv.dilationfactor(0)); //height
        wdilation = static_cast<int>(conv.dilationfactor(1)); //width
    }
    // Effective kernel extent once dilation is taken into account.
    int Kh_dilated = (Kh-1) * hdilation + 1;
    int Kw_dilated = (Kw-1) * wdilation + 1;

    outputShape.setChannel(static_cast<size_t>(conv.outputchannels()));
    // Padding amounts: left, right, bottom, top. They stay 0 unless "valid"
    // padding supplies exactly two border amounts.
    int l, r, b, t;
    l = r = b = t = 0;

    if (conv.isdeconvolution() && conv.outputshape_size() == 2){
        // Deconvolution with an explicit output shape: height and width are fixed directly.
        outputShape.setHeight(static_cast<size_t>(conv.outputshape(0)));
        outputShape.setWidth(static_cast<size_t>(conv.outputshape(1)));
    }
    else {
        switch (conv.ConvolutionPaddingType_case()) {
            case Specification::ConvolutionLayerParams::kValid:
                if (conv.valid().paddingamounts().borderamounts_size() == 2){
                    t = static_cast<int>(conv.valid().paddingamounts().borderamounts(0).startedgesize());
                    b = static_cast<int>(conv.valid().paddingamounts().borderamounts(0).endedgesize());
                    l = static_cast<int>(conv.valid().paddingamounts().borderamounts(1).startedgesize());
                    r = static_cast<int>(conv.valid().paddingamounts().borderamounts(1).endedgesize());
                }
                if (conv.isdeconvolution()) {
                    // Deconvolution: bound the input from the current output range first,
                    // then recompute the output range from the (possibly tightened) input.
                    // NOTE(review): the divisions by hstride / wstride below are plain integer
                    // divisions; a stride of 0 in the spec would divide by zero unless it is
                    // rejected elsewhere -- confirm.
                    int lowerBound = ((static_cast<int>(outputShape.heightRange().minimumValue()) + b + t - Kh_dilated) / hstride) + 1;
                    RangeValue upperBound = (outputShape.heightRange().maximumValue() + b + t - Kh_dilated) / static_cast<size_t>(hstride) + 1;
                    // Only a positive lower bound is applied to the input.
                    if (lowerBound > 0)
                        inputShape.lowerBoundHeight(static_cast<size_t>(lowerBound));
                    inputShape.upperBoundHeight(upperBound);
                    // Problem: clamping this is not associative, so what is the right order? But we already made sure it's clamped correctly above, so it's ok
                    outputShape.updateHeightRange((inputShape.heightRange() - 1) * static_cast<size_t>(hstride) + Kh_dilated - t - b);
                    // Same procedure for the width axis.
                    lowerBound = ((static_cast<int>(outputShape.widthRange().minimumValue()) + l + r - Kw_dilated) / wstride) + 1;
                    upperBound = (outputShape.widthRange().maximumValue() + l + r - Kw_dilated) / static_cast<size_t>(wstride) + 1;
                    if (lowerBound > 0)
                        inputShape.lowerBoundWidth(static_cast<size_t>(lowerBound));
                    inputShape.upperBoundWidth(upperBound);
                    outputShape.updateWidthRange((inputShape.widthRange() - 1) * static_cast<size_t>(wstride) + Kw_dilated - r - l);
                } else {
                    // Regular convolution: the input bounds are only derived when the dilated
                    // kernel is larger than the total padding along that axis.
                    if (Kh_dilated - b - t > 0) { // what's the else here? just doing the convolution on the padding?
                        // Avoid unsigned underflow when the output minimum is 0.
                        size_t inputLowerBound = outputShape.heightRange().minimumValue() == 0 ? 0 : (outputShape.heightRange().minimumValue() - 1);
                        inputShape.lowerBoundHeight(inputLowerBound * static_cast<size_t>(hstride) + static_cast<size_t>(Kh_dilated - b - t));
                        
                        RangeValue inputUpperBound = (outputShape.heightRange().maximumValue() - 1) * static_cast<size_t>(hstride) + static_cast<size_t>(Kh_dilated - b - t);
                        // We need to account for the integer division here
                        // NOTE(review): `%` binds tighter than `+`, so the condition below evaluates
                        // `value + (X % 2) != 0` rather than `(value + X) % 2 != 0`. Inside this branch
                        // `t + b - Kh_dilated` is negative before the cast to size_t. The constant 2 also
                        // looks like it assumes a stride of 2 -- confirm the intended expression.
                        if (!inputShape.heightRange().maximumValue().isUnbound() && (inputShape.heightRange().maximumValue().value() + static_cast<size_t>(t + b - Kh_dilated) % 2 != 0)) {
                            inputUpperBound = inputUpperBound + 1;
                        }
                        inputShape.upperBoundHeight(inputUpperBound);
                        
                    }
                    // Forward relation: out = (in + pad - K_dilated) / stride + 1
                    outputShape.updateHeightRange((inputShape.heightRange() + (t + b - Kh_dilated))/hstride + 1);
                    if (Kw_dilated - l - r > 0) {
                        size_t inputLowerBound = outputShape.widthRange().minimumValue() == 0 ? 0 : (outputShape.widthRange().minimumValue() - 1);
                        inputShape.lowerBoundWidth(inputLowerBound * static_cast<size_t>(wstride) + static_cast<size_t>(Kw_dilated - l - r));
                        
                        RangeValue inputUpperBound = (outputShape.widthRange().maximumValue() - 1) * static_cast<size_t>(wstride) + static_cast<size_t>(Kw_dilated - l - r);
                        // NOTE(review): same operator-precedence concern as in the height branch above.
                        if (!inputShape.widthRange().maximumValue().isUnbound() && (inputShape.widthRange().maximumValue().value() + static_cast<size_t>(l + r - Kw_dilated) % 2 != 0)) {
                            inputUpperBound = inputUpperBound + 1;
                        }
                        inputShape.upperBoundWidth(inputUpperBound);
                    }
                    outputShape.updateWidthRange((inputShape.widthRange() + (r + l - Kw_dilated)) / wstride + 1);
                }
                break;
            case Specification::ConvolutionLayerParams::kSame:
                if (conv.isdeconvolution()) {
                    // "Same" deconvolution: the output extent is the input extent times the stride.
                    outputShape.updateHeightRange(inputShape.heightRange() * static_cast<size_t>(hstride));
                    outputShape.updateWidthRange(inputShape.widthRange() * static_cast<size_t>(wstride));
                } else {
                    // needs to round up
                    // (in - 1) / stride + 1 is the rounded-up quotient in / stride.
                    outputShape.updateHeightRange((inputShape.heightRange() - 1) / static_cast<size_t>(hstride)+1);
                    inputShape.lowerBoundHeight(outputShape.heightRange().minimumValue());
                    // If the output has an upper bound, then the input can't be any larger
                    inputShape.upperBoundHeight(outputShape.heightRange().maximumValue() * static_cast<size_t>(hstride));
                    outputShape.updateWidthRange((inputShape.widthRange() - 1) / static_cast<size_t>(wstride)+1);
                    inputShape.lowerBoundWidth(outputShape.widthRange().minimumValue());
                    inputShape.upperBoundWidth(outputShape.widthRange().maximumValue() * static_cast<size_t>(wstride));
                }
                break;
            case Specification::ConvolutionLayerParams::CONVOLUTIONPADDINGTYPE_NOT_SET:
                // No padding type to work with: shape inference cannot proceed.
                throw std::runtime_error("Convolution padding type not set");
        }
    }

#if COREML_VALIDATOR_VERBOSE
    std::cout << "Convolution layer " << specLayer.name() << " input shapes (after): " << std::endl;
    std::cout << inputShape;
    std::cout << "Convolution layer " << specLayer.name() << " output shapes (after): " << std::endl;
    std::cout << outputShape;
#endif

} // convolution shape

// Propagates shape constraints through a pooling layer.
//
// Sequence, batch and channel ranges are exchanged between the layer's first
// input and first output blob. Height and width are then related according to
// the pooling mode:
//   - global pooling: the output height and width are fixed to 1,
//   - valid padding: out = (in + pad - K) / stride + 1,
//   - same padding: out = (in - 1) / stride + 1 (rounded-up in / stride),
//   - include-last-pixel: rounded-up division with an optional -1 correction.
//
// Throws std::runtime_error when the padding type is not set and the layer is
// not a global pooling layer.
void NeuralNetworkShaper::shapePoolingLayer(const Specification::NeuralNetworkLayer& specLayer) {

    ShapeConstraint& inputShape = blobShapes[specLayer.input(0)];
    ShapeConstraint& outputShape = blobShapes[specLayer.output(0)];
    outputShape.setName(specLayer.output(0));

#if COREML_VALIDATOR_VERBOSE
    std::cout << "Pooling layer " << specLayer.name() << " input shapes (before): " << std::endl;
    std::cout << inputShape;
    std::cout << "Pooling layer " << specLayer.name() << " output shapes (before): " << std::endl;
    std::cout << outputShape;
#endif

    // Pooling leaves sequence, batch and channel untouched: exchange them both ways.
    outputShape.updateSequenceRange(outputShape.sequenceRange().intersect(inputShape.sequenceRange()));
    outputShape.updateBatchRange(outputShape.batchRange().intersect(inputShape.batchRange()));
    outputShape.updateChannelRange(inputShape.channelRange());

    inputShape.updateSequenceRange(outputShape.sequenceRange());
    inputShape.updateBatchRange(outputShape.batchRange());
    inputShape.updateChannelRange(outputShape.channelRange());

    const Specification::PoolingLayerParams& pool = specLayer.pooling();

    // Defaults used when the repeated fields do not hold exactly two entries:
    // 3x3 kernel, stride 1.
    int Kh, Kw, hstride, wstride;
    Kh = Kw = 3;
    hstride = wstride = 1;
    if (pool.kernelsize_size() == 2){
        Kh = static_cast<int>(pool.kernelsize(0)); //height
        Kw = static_cast<int>(pool.kernelsize(1)); //width
    }
    if (pool.stride_size() == 2){
        hstride = static_cast<int>(pool.stride(0)); //height
        wstride = static_cast<int>(pool.stride(1)); //width
    }

    // Padding amounts: left, right, bottom, top.
    int l, r, b, t;
    l = r = b = t = 0;

    if (pool.globalpooling()) {
        // Global pooling collapses the spatial extent to a single value.
        outputShape.setHeight(1);
        outputShape.setWidth(1);
    } else {
        switch (pool.PoolingPaddingType_case()) {
            case Specification::PoolingLayerParams::kValid:
                if (pool.valid().paddingamounts().borderamounts_size() == 2){
                    t = static_cast<int>(pool.valid().paddingamounts().borderamounts(0).startedgesize());
                    b = static_cast<int>(pool.valid().paddingamounts().borderamounts(0).endedgesize());
                    l = static_cast<int>(pool.valid().paddingamounts().borderamounts(1).startedgesize());
                    r = static_cast<int>(pool.valid().paddingamounts().borderamounts(1).endedgesize());
                }

                // The input lower bound is only derived when the kernel is larger
                // than the total padding along that axis.
                if (Kh - b - t > 0) {
                    // Avoid unsigned underflow when the output minimum is 0.
                    size_t inputLowerBound = outputShape.heightRange().minimumValue() == 0 ? 0 : (outputShape.heightRange().minimumValue() - 1);
                    inputShape.lowerBoundHeight(inputLowerBound * static_cast<size_t>(hstride) + static_cast<size_t>(Kh - b - t));
                }

                outputShape.updateHeightRange((inputShape.heightRange() + (t + b - Kh))/static_cast<size_t>(hstride) + 1);

                if (Kw - l - r > 0) {
                    size_t inputLowerBound = outputShape.widthRange().minimumValue() == 0 ? 0 : (outputShape.widthRange().minimumValue() - 1);
                    inputShape.lowerBoundWidth(inputLowerBound * static_cast<size_t>(wstride) + static_cast<size_t>(Kw - l - r));
                }
                outputShape.updateWidthRange((inputShape.widthRange() + (r + l - Kw))/static_cast<size_t>(wstride) + 1);

                break;
            case Specification::PoolingLayerParams::kSame:
                // needs to round up
                outputShape.updateHeightRange((inputShape.heightRange() - 1)/ static_cast<size_t>(hstride) + 1);
                outputShape.updateWidthRange((inputShape.widthRange() - 1) / static_cast<size_t>(wstride) + 1);
                break;
            case Specification::PoolingLayerParams::kIncludeLastPixel: {
                // In this mode a single padding amount per axis is read and counted on
                // both sides (hence the 2*t / 2*l terms below).
                if (pool.includelastpixel().paddingamounts_size() == 2){
                    t = static_cast<int>(pool.includelastpixel().paddingamounts(0));
                    l = static_cast<int>(pool.includelastpixel().paddingamounts(1));
                }
                // subtracting in the numerator and adding 1 in order to get integer division rounded up
                int addExtra = 0;
                if (t || l) {
                    if ((static_cast<int>(inputShape.minimumHeight()) - 1 + (2*t - Kh - 1) + 2 * hstride) >= (static_cast<int>(inputShape.minimumHeight()) + t)) {
                        addExtra = -1;
                    }
                }
                outputShape.updateHeightRange((inputShape.heightRange() + (2*t - Kh)).divideAndRoundUp(static_cast<size_t>(hstride)) + 1 + addExtra);
                if (Kh - 2*t > 0)
                    inputShape.lowerBoundHeight(outputShape.heightRange().minimumValue() - 1 + static_cast<size_t>(Kh - 2*t));

                addExtra = 0;
                if (t || l){
                    if ((static_cast<int>(inputShape.minimumWidth()) - 1 + (2*l - Kw - 1) + 2 * wstride) >= (static_cast<int>(inputShape.minimumWidth()) + l)) {
                        addExtra = -1;
                    }
                }
                outputShape.updateWidthRange((inputShape.widthRange() + (2*l - Kw)).divideAndRoundUp(static_cast<size_t>(wstride)) + 1 + addExtra);
                // The width-derived bound must constrain the input *width*. The previous
                // code applied it to the input height, which mis-bounded the height and
                // left the width unconstrained.
                if (Kw - 2*l > 0)
                    inputShape.lowerBoundWidth(outputShape.widthRange().minimumValue() - 1 + static_cast<size_t>(Kw - 2*l));
                break;
            }
            case Specification::PoolingLayerParams::POOLINGPADDINGTYPE_NOT_SET:
                throw std::runtime_error("Pooling padding type not set");
        }
    }

#if COREML_VALIDATOR_VERBOSE
    std::cout << "Pooling layer " << specLayer.name() << " input shapes (after): " << std::endl;
    std::cout << inputShape;
    std::cout << "Pooling layer " << specLayer.name() << " output shapes (after): " << std::endl;
    std::cout << outputShape;
#endif


}

// Shape propagation for layers whose output blob has the same shape as their
// input blob: the constraints of the first input and the first output are
// synchronized with each other.
void NeuralNetworkShaper::shapeUnchanged(const Specification::NeuralNetworkLayer& specLayer) {

    ShapeConstraint& source = blobShapes[specLayer.input(0)];
    ShapeConstraint& result = blobShapes[specLayer.output(0)];

    // The output constraint may not have been named yet.
    result.setName(specLayer.output(0));

    // copyFrom handles all 5 axes (intersecting them, per the original note);
    // it is applied in both directions so that both blobs end up agreeing.
    result.copyFrom(source);
    source.copyFrom(result);

}

// Propagates shape constraints through an inner product (fully connected) layer.
//
// Sequence and batch ranges are exchanged between the first input and the
// first output blob. The input is pinned to (inputchannels, 1, 1) and the
// output to (outputchannels, 1, 1) in channel/height/width.
void NeuralNetworkShaper::shapeInnerProductLayer(const Specification::NeuralNetworkLayer& specLayer) {

    ShapeConstraint& inputShape = blobShapes[specLayer.input(0)];

    ShapeConstraint& outputShape = blobShapes[specLayer.output(0)];
    outputShape.setName(specLayer.output(0));

#if COREML_VALIDATOR_VERBOSE
    std::cout << "Inner Product layer " << specLayer.name() << " input shapes (before): " << std::endl;
    std::cout << inputShape;
    std::cout << "Inner Product layer " << specLayer.name() << " output shapes (before): " << std::endl;
    std::cout << outputShape;
#endif

    // Sequence and batch pass through unchanged: intersect, then push back to the input.
    outputShape.updateSequenceRange(outputShape.sequenceRange().intersect(inputShape.sequenceRange()));
    outputShape.updateBatchRange(outputShape.batchRange().intersect(inputShape.batchRange()));

    inputShape.updateSequenceRange(outputShape.sequenceRange());
    inputShape.updateBatchRange(outputShape.batchRange());

    const Specification::InnerProductLayerParams& ip = specLayer.innerproduct();

    // The layer consumes a flat vector of inputchannels values...
    inputShape.setChannel(static_cast<size_t>(ip.inputchannels()));
    inputShape.setHeight(1);
    inputShape.setWidth(1);

    // ...and produces a flat vector of outputchannels values.
    outputShape.setChannel(static_cast<size_t>(ip.outputchannels()));
    outputShape.setHeight(1);
    outputShape.setWidth(1);

#if COREML_VALIDATOR_VERBOSE
    std::cout << "Inner Product layer " << specLayer.name() << " input shapes (after): " << std::endl;
    std::cout << inputShape;
    std::cout << "Inner Product layer " << specLayer.name() << " output shapes (after): " << std::endl;
    std::cout << outputShape;
#endif


}

// Propagates shape constraints through an embedding layer.
//
// Sequence and batch ranges are exchanged between the first input and the
// first output blob. The input is pinned to 1x1x1 in channel/height/width and
// the output to (outputchannels, 1, 1).
void NeuralNetworkShaper::shapeEmbeddingLayer(const Specification::NeuralNetworkLayer& specLayer) {

    ShapeConstraint& inputShape = blobShapes[specLayer.input(0)];

    // Check if this already exists? Can't -- outputs are only generated by one layer
    ShapeConstraint& outputShape = blobShapes[specLayer.output(0)];
    outputShape.setName(specLayer.output(0));

#if COREML_VALIDATOR_VERBOSE
    std::cout << "Embedding layer " << specLayer.name() << " input shapes (before): " << std::endl;
    std::cout << inputShape;
    std::cout << "Embedding layer " << specLayer.name() << " output shapes (before): " << std::endl;
    std::cout << outputShape;
#endif

    outputShape.updateSequenceRange(inputShape.sequenceRange());
    outputShape.updateBatchRange(inputShape.batchRange());

    inputShape.updateSequenceRange(outputShape.sequenceRange());
    inputShape.updateBatchRange(outputShape.batchRange());

    // Bind by const reference: only outputchannels is read, so there is no
    // reason to copy the whole parameter message.
    const Specification::EmbeddingLayerParams& embedding = specLayer.embedding();

    // all three inputs need to be one
    inputShape.setChannel(1);
    inputShape.setHeight(1);
    inputShape.setWidth(1);

    outputShape.setChannel(static_cast<size_t>(embedding.outputchannels()));
    outputShape.setHeight(1);
    outputShape.setWidth(1);

#if COREML_VALIDATOR_VERBOSE
    std::cout << "Embedding layer " << specLayer.name() << " input shapes (after): " << std::endl;
    std::cout << inputShape;
    std::cout << "Embedding layer " << specLayer.name() << " output shapes (after): " << std::endl;
    std::cout << outputShape;
#endif

}

// Propagates shape constraints through a crop layer.
//
// Sequence, batch and channel ranges are exchanged between the first input and
// the first output blob. With a single input, the output height/width is the
// input height/width minus the crop amounts, and the input must be at least one
// element larger than the total crop along each axis. With more than one input,
// the output height/width is taken from the second input blob.
void NeuralNetworkShaper::shapeCropLayer(const Specification::NeuralNetworkLayer& specLayer) {

    ShapeConstraint& inputShape = blobShapes[specLayer.input(0)];
    ShapeConstraint& outputShape = blobShapes[specLayer.output(0)];
    outputShape.setName(specLayer.output(0));

#if COREML_VALIDATOR_VERBOSE
    std::cout << "Crop layer " << specLayer.name() << " input shapes (before): " << std::endl;
    std::cout << inputShape;
    std::cout << "Crop layer " << specLayer.name() << " output shapes (before): " << std::endl;
    std::cout << outputShape;
#endif

    outputShape.updateSequenceRange(outputShape.sequenceRange().intersect(inputShape.sequenceRange()));
    outputShape.updateBatchRange(outputShape.batchRange().intersect(inputShape.batchRange()));
    outputShape.updateChannelRange(outputShape.channelRange().intersect(inputShape.channelRange()));

    inputShape.updateSequenceRange(outputShape.sequenceRange());
    inputShape.updateBatchRange(outputShape.batchRange());
    inputShape.updateChannelRange(outputShape.channelRange());

    // Bind by const reference; the parameters are only read.
    const Specification::CropLayerParams& crop = specLayer.crop();

    // Crop amounts: left, right, top, bottom. They stay 0 unless exactly two
    // border amounts are present.
    int l , r, t, b;
    l = r = t = b = 0;
    if (specLayer.input_size() == 1){
        if (crop.cropamounts().borderamounts_size() == 2){
            t = static_cast<int>(crop.cropamounts().borderamounts(0).startedgesize());
            b = static_cast<int>(crop.cropamounts().borderamounts(0).endedgesize());
            l = static_cast<int>(crop.cropamounts().borderamounts(1).startedgesize());
            r = static_cast<int>(crop.cropamounts().borderamounts(1).endedgesize());
        }
        outputShape.updateHeightRange(inputShape.heightRange() - t - b);
        outputShape.updateWidthRange(inputShape.widthRange() - r - l);
        // At least one element has to survive the crop along each axis.
        inputShape.lowerBoundHeight(static_cast<size_t>(t+b+1));
        inputShape.lowerBoundWidth(static_cast<size_t>(r+l+1));
    } else {
        // The second input only provides the spatial extent; its channel
        // information is ignored and the channels come from the first input.
        ShapeConstraint& referenceShape = blobShapes[specLayer.input(1)];
        outputShape.updateChannelRange(inputShape.channelRange());
        outputShape.updateHeightRange(referenceShape.heightRange());
        outputShape.updateWidthRange(referenceShape.widthRange());
    }

#if COREML_VALIDATOR_VERBOSE
    std::cout << "Crop layer " << specLayer.name() << " input shapes (after): " << std::endl;
    std::cout << inputShape;
    std::cout << "Crop layer " << specLayer.name() << " output shapes (after): " << std::endl;
    std::cout << outputShape;
#endif

}

// Propagates shape constraints through a padding layer.
//
// Sequence, batch and channel ranges are exchanged between the first input and
// the first output blob. The output height/width is the input height/width
// plus the padding amounts, and the relation is also applied backwards so that
// an output constraint tightens the input.
void NeuralNetworkShaper::shapePaddingLayer(const Specification::NeuralNetworkLayer& specLayer) {

    ShapeConstraint& inputShape = blobShapes[specLayer.input(0)];
    ShapeConstraint& outputShape = blobShapes[specLayer.output(0)];
    outputShape.setName(specLayer.output(0));

#if COREML_VALIDATOR_VERBOSE
    std::cout << "Padding layer " << specLayer.name() << " input shapes (before): " << std::endl;
    std::cout << inputShape;
    std::cout << "Padding layer " << specLayer.name() << " output shapes (before): " << std::endl;
    std::cout << outputShape;
#endif

    outputShape.updateSequenceRange(outputShape.sequenceRange().intersect(inputShape.sequenceRange()));
    outputShape.updateBatchRange(outputShape.batchRange().intersect(inputShape.batchRange()));

    inputShape.updateSequenceRange(outputShape.sequenceRange());
    inputShape.updateBatchRange(outputShape.batchRange());
    inputShape.updateChannelRange(outputShape.channelRange());

    // Bind by const reference; the parameters are only read.
    const Specification::PaddingLayerParams& padding = specLayer.padding();

    // Padding amounts: left, right, top, bottom. They stay 0 unless exactly two
    // border amounts are present.
    size_t l, r, t, b;
    l = r = t = b = 0;
    if (padding.paddingamounts().borderamounts_size() == 2){
        t = static_cast<size_t>(padding.paddingamounts().borderamounts(0).startedgesize());
        b = static_cast<size_t>(padding.paddingamounts().borderamounts(0).endedgesize());
        l = static_cast<size_t>(padding.paddingamounts().borderamounts(1).startedgesize());
        r = static_cast<size_t>(padding.paddingamounts().borderamounts(1).endedgesize());
    }
    outputShape.updateChannelRange(outputShape.channelRange().intersect(inputShape.channelRange()));
    outputShape.updateHeightRange(inputShape.heightRange() + t + b);
    outputShape.updateWidthRange(inputShape.widthRange() + r + l);

    // If a fixed constraint for the output shape emerges later in the network, then we can infer a lower bound for the input shape
    inputShape.updateHeightRange(outputShape.heightRange() - t - b);
    inputShape.updateWidthRange(outputShape.widthRange() - l - r);

#if COREML_VALIDATOR_VERBOSE
    std::cout << "Padding layer " << specLayer.name() << " input shapes (after): " << std::endl;
    std::cout << inputShape;
    std::cout << "Padding layer " << specLayer.name() << " output shapes (after): " << std::endl;
    std::cout << outputShape;
#endif

}

// Propagates shape constraints through an upsample layer.
//
// Sequence, batch and channel ranges are exchanged between the first input and
// the first output blob. The output height/width is the input height/width
// multiplied by the scaling factors, and the inverse relation is applied back
// to the input.
void NeuralNetworkShaper::shapeUpsampleLayer(const Specification::NeuralNetworkLayer& specLayer) {

    ShapeConstraint& inputShape = blobShapes[specLayer.input(0)];
    ShapeConstraint& outputShape = blobShapes[specLayer.output(0)];
    outputShape.setName(specLayer.output(0));

#if COREML_VALIDATOR_VERBOSE
    std::cout << "Upsample layer " << specLayer.name() << " input shapes (before): " << std::endl;
    std::cout << inputShape;
    std::cout << "Upsample layer " << specLayer.name() << " output shapes (before): " << std::endl;
    std::cout << outputShape;
#endif

    outputShape.updateSequenceRange(outputShape.sequenceRange().intersect(inputShape.sequenceRange()));
    outputShape.updateBatchRange(outputShape.batchRange().intersect(inputShape.batchRange()));
    outputShape.updateChannelRange(outputShape.channelRange().intersect(inputShape.channelRange()));

    inputShape.updateSequenceRange(outputShape.sequenceRange());
    inputShape.updateBatchRange(outputShape.batchRange());
    inputShape.updateChannelRange(outputShape.channelRange());

    // Bind by const reference; the parameters are only read.
    const Specification::UpsampleLayerParams& upsample = specLayer.upsample();

    // Scaling factors default to 1. A factor of 0 in the spec is also treated
    // as 1, which keeps the divisions below well defined.
    size_t scaling_factor_h = 1;
    size_t scaling_factor_w = 1;

    if (upsample.scalingfactor_size() == 2) {
        scaling_factor_h = (upsample.scalingfactor(0) == 0) ? 1 : static_cast<size_t>(upsample.scalingfactor(0)); //height
        scaling_factor_w = (upsample.scalingfactor(1) == 0) ? 1 : static_cast<size_t>(upsample.scalingfactor(1)); //width
    }

    outputShape.updateHeightRange(inputShape.heightRange() * scaling_factor_h);
    outputShape.updateWidthRange(inputShape.widthRange() * scaling_factor_w);

    // Backward direction: a constraint on the output tightens the input.
    inputShape.updateHeightRange(outputShape.heightRange() / scaling_factor_h);
    inputShape.updateWidthRange(outputShape.widthRange() / scaling_factor_w);

#if COREML_VALIDATOR_VERBOSE
    std::cout << "Upsample layer " << specLayer.name() << " input shapes (after): " << std::endl;
    std::cout << inputShape;
    std::cout << "Upsample layer " << specLayer.name() << " output shapes (after): " << std::endl;
    std::cout << outputShape;
#endif

}

// Shape propagation for multi-input layers whose inputs are broadcast against
// each other.
//
// Sequence and batch ranges of all inputs are intersected; channel, height and
// width ranges are combined with ShapeRange::unify. When every input has a
// fixed C/H/W, the output C/H/W is set to the maximum of the combined ranges;
// otherwise the output ranges are intersected with the combined ranges. The
// result is then pushed back to every input that does not have a fixed C/H/W.
void NeuralNetworkShaper::shapeBroadcastLayer(const Specification::NeuralNetworkLayer& specLayer) {

    const int numInputs = specLayer.input_size();

    // Seed the accumulated ranges from the first input.
    ShapeConstraint& firstInput = blobShapes[specLayer.input(0)];

    ShapeRange seqAccum = firstInput.sequenceRange();
    ShapeRange batchAccum = firstInput.batchRange();
    ShapeRange channelAccum = firstInput.channelRange();
    ShapeRange heightAccum = firstInput.heightRange();
    ShapeRange widthAccum = firstInput.widthRange();

    bool allInputsFixed = firstInput.hasFixedCHW();

    // Fold the remaining inputs into the accumulated ranges.
    for (int idx = 1; idx < numInputs; ++idx) {

        ShapeConstraint& other = blobShapes[specLayer.input(idx)];

        if (allInputsFixed && !other.hasFixedCHW()) {
            allInputsFixed = false;
        }

        seqAccum = seqAccum.intersect(other.sequenceRange());
        batchAccum = batchAccum.intersect(other.batchRange());

        channelAccum = channelAccum.unify(other.channelRange());
        heightAccum = heightAccum.unify(other.heightRange());
        widthAccum = widthAccum.unify(other.widthRange());

    }

    ShapeConstraint& outputShape = blobShapes[specLayer.output(0)];
    outputShape.setName(specLayer.output(0));

#if COREML_VALIDATOR_VERBOSE
    std::cout << "Broadcast layer " << specLayer.name() << " input shapes (before): " << std::endl;
    for (int idx = 0; idx < numInputs; ++idx) {
        std::cout << blobShapes[specLayer.input(idx)];
    }
    std::cout << "Broadcast layer " << specLayer.name() << " output shapes (before): " << std::endl;
    std::cout << outputShape;
#endif

    // The output of such a layer can remain ambiguous (a range rather than a
    // single value) when the inputs are not all fixed.

    outputShape.updateSequenceRange(outputShape.sequenceRange().intersect(seqAccum));
    outputShape.updateBatchRange(outputShape.batchRange().intersect(batchAccum));

    if (!allInputsFixed) {
        outputShape.updateChannelRange(outputShape.channelRange().intersect(channelAccum));
        outputShape.updateHeightRange(outputShape.heightRange().intersect(heightAccum));
        outputShape.updateWidthRange(outputShape.widthRange().intersect(widthAccum));
    }
    else {
        outputShape.setChannel(channelAccum.maximum().value());
        outputShape.setHeight(heightAccum.maximum().value());
        outputShape.setWidth(widthAccum.maximum().value());
    }

    // Push the resolved output constraints back to every input.
    for (int idx = 0; idx < numInputs; ++idx) {

        ShapeConstraint& current = blobShapes[specLayer.input(idx)];
        current.updateSequenceRange(outputShape.sequenceRange());
        current.updateBatchRange(outputShape.batchRange());

        // A load constant layer feeding a broadcast layer is a common pattern.
        // Such an input already has a fixed shape, so its C/H/W is left alone.
        if (current.hasFixedCHW()) {
            continue;
        }

        current.updateChannelRange(outputShape.channelRange());
        current.updateHeightRange(outputShape.heightRange());
        current.updateWidthRange(outputShape.widthRange());

    }


#if COREML_VALIDATOR_VERBOSE
    std::cout << "Broadcast layer " << specLayer.name() << " input shapes (after): " << std::endl;
    for (int idx = 0; idx < numInputs; ++idx) {
        std::cout << blobShapes[specLayer.input(idx)];
    }
    std::cout << "Broadcast layer " << specLayer.name() << " output shapes (after): " << std::endl;
    std::cout << outputShape;
#endif

}


// Propagates shape constraints through a dot product layer (two inputs, one
// output).
//
// Sequence and batch ranges of the output are intersected with those of both
// inputs and pushed back to them. Both inputs are pinned to height 1 and
// width 1 and synchronized with each other; the output is pinned to 1x1x1 in
// channel/height/width.
void NeuralNetworkShaper::shapeDotLayer(const Specification::NeuralNetworkLayer& specLayer) {

    //get the input shape
    ShapeConstraint& inputShape1 = blobShapes[specLayer.input(0)];
    ShapeConstraint& inputShape2 = blobShapes[specLayer.input(1)];
    // Must be a reference: with a by-value copy, every update made to the output
    // constraint below (name, sequence/batch, C/H/W) was discarded when the
    // function returned and never reached blobShapes.
    ShapeConstraint& outputShape = blobShapes[specLayer.output(0)];
    outputShape.setName(specLayer.output(0));
#if COREML_VALIDATOR_VERBOSE
    std::cout << "Dot layer " << specLayer.name() << " input shapes (before): " << std::endl;
    std::cout << inputShape1;
    std::cout << inputShape2;
    std::cout << "Dot layer " << specLayer.name() << " output shapes (before): " << std::endl;
    std::cout << outputShape;
#endif

    // The output sequence/batch must be compatible with both inputs.
    outputShape.updateSequenceRange(outputShape.sequenceRange().intersect(inputShape1.sequenceRange()));
    outputShape.updateBatchRange(outputShape.batchRange().intersect(inputShape1.batchRange()));
    outputShape.updateSequenceRange(outputShape.sequenceRange().intersect(inputShape2.sequenceRange()));
    outputShape.updateBatchRange(outputShape.batchRange().intersect(inputShape2.batchRange()));

    inputShape1.setHeight(1);
    inputShape1.setWidth(1);
    inputShape2.setHeight(1);
    inputShape2.setWidth(1);

    // The inputs need to be equal
    inputShape1.copyFrom(inputShape2);
    inputShape2.copyFrom(inputShape1);

    inputShape1.updateSequenceRange(outputShape.sequenceRange());
    inputShape1.updateBatchRange(outputShape.batchRange());

    inputShape2.updateSequenceRange(outputShape.sequenceRange());
    inputShape2.updateBatchRange(outputShape.batchRange());

    // The dot product yields a single value per sequence/batch element.
    outputShape.setChannel(1);
    outputShape.setHeight(1);
    outputShape.setWidth(1);

#if COREML_VALIDATOR_VERBOSE
    std::cout << "Dot layer " << specLayer.name() << " input shapes (after): " << std::endl;
    std::cout << inputShape1;
    std::cout << inputShape2;
    std::cout << "Dot layer " << specLayer.name() << " output shapes (after): " << std::endl;
    std::cout << outputShape;
#endif


}

// Propagates shape constraints through a Reduce layer.
//
// Sequence and batch ranges are intersected between input(0) and output(0) and
// written to both blobs. Every reduced axis is pinned to 1 on the output; every
// axis that is not reduced has its range intersected into the output and then
// mirrored back onto the input.
//
// Throws std::runtime_error if the axis enum holds one of the protobuf sentinel values.
void NeuralNetworkShaper::shapeReduceLayer(const Specification::NeuralNetworkLayer& specLayer) {

    ShapeConstraint& inputShape = blobShapes[specLayer.input(0)];
    ShapeConstraint& outputShape = blobShapes[specLayer.output(0)];
    outputShape.setName(specLayer.output(0));

    outputShape.updateSequenceRange(outputShape.sequenceRange().intersect(inputShape.sequenceRange()));
    outputShape.updateBatchRange(outputShape.batchRange().intersect(inputShape.batchRange()));

    inputShape.updateSequenceRange(outputShape.sequenceRange());
    inputShape.updateBatchRange(outputShape.batchRange());

#if COREML_VALIDATOR_VERBOSE
    std::cout << "Reduce layer " << specLayer.name() << " input shapes (before): " << std::endl;
    std::cout << inputShape;
    std::cout << "Reduce layer " << specLayer.name() << " output shapes (before): " << std::endl;
    std::cout << outputShape;
#endif

    // The params are only read, so bind by reference instead of copying the message.
    const Specification::ReduceLayerParams& reduce = specLayer.reduce();

    // One helper per axis that can pass through the reduction untouched: intersect the
    // range into the output, then mirror the result back to the input. Keeping the pair
    // in one place avoids the copy/paste slip where the mirror step targeted the wrong blob.
    auto passThroughChannel = [&]() {
        outputShape.updateChannelRange(outputShape.channelRange().intersect(inputShape.channelRange()));
        inputShape.updateChannelRange(outputShape.channelRange());
    };
    auto passThroughHeight = [&]() {
        outputShape.updateHeightRange(outputShape.heightRange().intersect(inputShape.heightRange()));
        inputShape.updateHeightRange(outputShape.heightRange());
    };
    auto passThroughWidth = [&]() {
        outputShape.updateWidthRange(outputShape.widthRange().intersect(inputShape.widthRange()));
        inputShape.updateWidthRange(outputShape.widthRange());
    };

    // No default case on purpose: the switch enumerates every value of the axis enum.
    switch (reduce.axis()) {
        case Specification::ReduceLayerParams::CHW: {
            outputShape.setChannel(1);
            outputShape.setHeight(1);
            outputShape.setWidth(1);
            break;
        }
        case Specification::ReduceLayerParams::HW: {
            passThroughChannel();
            outputShape.setHeight(1);
            outputShape.setWidth(1);
            break;
        }
        case Specification::ReduceLayerParams::H: {
            passThroughChannel();
            outputShape.setHeight(1);
            passThroughWidth();
            break;
        }
        case Specification::ReduceLayerParams::W: {
            passThroughChannel();
            passThroughHeight();
            outputShape.setWidth(1);
            break;
        }
        case Specification::ReduceLayerParams::C: {
            outputShape.setChannel(1);
            passThroughHeight();
            // Previously the width was written back to the output instead of the input,
            // so the input never received the width constraint in this case.
            passThroughWidth();
            break;
        }
        case CoreML::Specification::ReduceLayerParams_ReduceAxis_ReduceLayerParams_ReduceAxis_INT_MIN_SENTINEL_DO_NOT_USE_:
        case CoreML::Specification::ReduceLayerParams_ReduceAxis_ReduceLayerParams_ReduceAxis_INT_MAX_SENTINEL_DO_NOT_USE_: {
            throw std::runtime_error("Reduce layer axis not set -- should have been caught in validator.");
        }
    }

#if COREML_VALIDATOR_VERBOSE
    std::cout << "Reduce layer " << specLayer.name() << " input shapes (after): " << std::endl;
    std::cout << inputShape;
    std::cout << "Reduce layer " << specLayer.name() << " output shapes (after): " << std::endl;
    std::cout << outputShape;
#endif

}

// Fixes the output shape of a Load Constant layer from the shape stored in its params.
//
// Only output(0) is touched: sequence and batch are pinned to 1, and channel,
// height and width are taken from the first three entries of the params' shape.
//
// Throws std::runtime_error if the params' shape has fewer than three entries.
void NeuralNetworkShaper::shapeLoadConstantLayer(const Specification::NeuralNetworkLayer& specLayer) {

    // This function constrains the output blob only; no input blob is read.
    ShapeConstraint& outputShape = blobShapes[specLayer.output(0)];
    outputShape.setName(specLayer.output(0));

#if COREML_VALIDATOR_VERBOSE
    std::cout << "Load Constant layer " << specLayer.name() << " output shapes (before): " << std::endl;
    std::cout << outputShape;
#endif

    // Bind by reference: only three shape entries are read, so there is no reason to
    // copy the whole params message (which presumably also carries the constant's data).
    const Specification::LoadConstantLayerParams& load_constant = specLayer.loadconstant();

    // Indexing a repeated field past its size is not checked by the accessor, so fail loudly instead.
    if (load_constant.shape_size() < 3) {
        throw std::runtime_error("Load Constant layer '" + specLayer.name() + "' must have a shape with at least 3 dimensions.");
    }

    outputShape.setSequence(1);
    outputShape.setBatch(1);
    outputShape.setChannel(static_cast<size_t>(load_constant.shape(0)));
    outputShape.setHeight(static_cast<size_t>(load_constant.shape(1)));
    outputShape.setWidth(static_cast<size_t>(load_constant.shape(2)));

#if COREML_VALIDATOR_VERBOSE
    std::cout << "Load Constant layer " << specLayer.name() << " output shapes (after): " << std::endl;
    std::cout << outputShape;
#endif

}

// Applies the target shape of a Reshape layer to output(0).
//
// The output batch range is updated from the input. A 3-entry target shape is read as
// (channel, height, width) and the sequence range is intersected with the input's;
// otherwise the target shape is read as (sequence, channel, height, width).
// The input blob is not modified here.
void NeuralNetworkShaper::shapeReshapeLayer(const Specification::NeuralNetworkLayer& specLayer) {

    ShapeConstraint& source = blobShapes[specLayer.input(0)];
    ShapeConstraint& target = blobShapes[specLayer.output(0)];
    target.setName(specLayer.output(0));

#if COREML_VALIDATOR_VERBOSE
    std::cout << "Reshape layer " << specLayer.name() << " input shapes (before): " << std::endl;
    std::cout << source;
    std::cout << "Reshape layer " << specLayer.name() << " output shapes (before): " << std::endl;
    std::cout << target;
#endif

    // TODO: there is a cross-shape constraint that is not modelled here -- even if all 3 or 4
    // input ranges are individually unconstrained, they are now restricted to combinations
    // whose product equals the product of the output shape.

    const Specification::ReshapeLayerParams& params = specLayer.reshape();

    target.updateBatchRange(source.batchRange());

    // With three entries the sequence axis is carried over from the input and the
    // channel entry sits at index 0; otherwise index 0 is the sequence entry.
    const bool keepsInputSequence = (params.targetshape_size() == 3);
    const int channelIndex = keepsInputSequence ? 0 : 1;

    if (keepsInputSequence) {
        target.updateSequenceRange(target.sequenceRange().intersect(source.sequenceRange()));
    } else {
        // NOTE(review): a bare size is handed to updateSequenceRange here instead of using
        // setSequence -- confirm the implicit conversion produces the intended constraint.
        target.updateSequenceRange(static_cast<size_t>(params.targetshape(0)));
    }

    target.setChannel(static_cast<size_t>(params.targetshape(channelIndex)));
    target.setHeight(static_cast<size_t>(params.targetshape(channelIndex + 1)));
    target.setWidth(static_cast<size_t>(params.targetshape(channelIndex + 2)));

#if COREML_VALIDATOR_VERBOSE
    std::cout << "Reshape layer " << specLayer.name() << " input shapes (after): " << std::endl;
    std::cout << source;
    std::cout << "Reshape layer " << specLayer.name() << " output shapes (after): " << std::endl;
    std::cout << target;
#endif

}

// Propagates shape constraints through a Flatten layer.
//
// Sequence and batch ranges are exchanged between input(0) and output(0). The output
// channel range is updated with the product of the input's channel, height and width
// ranges, and the output height and width are pinned to 1.
void NeuralNetworkShaper::shapeFlattenLayer(const Specification::NeuralNetworkLayer& specLayer) {

    ShapeConstraint& source = blobShapes[specLayer.input(0)];
    ShapeConstraint& flat = blobShapes[specLayer.output(0)];
    flat.setName(specLayer.output(0));

#if COREML_VALIDATOR_VERBOSE
    std::cout << "Flatten layer " << specLayer.name() << " input shapes (before): " << std::endl;
    std::cout << source;
    std::cout << "Flatten layer " << specLayer.name() << " output shapes (before): " << std::endl;
    std::cout << flat;
#endif

    // Forward pass for the axes that flatten leaves alone...
    flat.updateSequenceRange(source.sequenceRange());
    flat.updateBatchRange(source.batchRange());

    // ...and the matching reverse pass.
    source.updateSequenceRange(flat.sequenceRange());
    source.updateBatchRange(flat.batchRange());

    // All of C, H and W collapse into the channel axis.
    const ShapeRange collapsed = source.channelRange() * source.heightRange() * source.widthRange();
    flat.updateChannelRange(collapsed);
    flat.setHeight(1);
    flat.setWidth(1);

#if COREML_VALIDATOR_VERBOSE
    std::cout << "Flatten layer " << specLayer.name() << " input shapes (after): " << std::endl;
    std::cout << source;
    std::cout << "Flatten layer " << specLayer.name() << " output shapes (after): " << std::endl;
    std::cout << flat;
#endif

}

// Propagates shape constraints through a Permute layer.
//
// The input's (sequence, channel, height, width) ranges are indexed 0..3; output
// dimension i receives the input range selected by axis(i). The resulting output
// ranges are then mapped back through the same permutation onto the input, and the
// output batch range is applied to the input.
//
// Throws std::runtime_error if fewer than four axes are given or any axis is > 3.
void NeuralNetworkShaper::shapePermuteLayer(const Specification::NeuralNetworkLayer& specLayer) {

    ShapeConstraint& inputShape = blobShapes[specLayer.input(0)];

    // Snapshot of the input ranges in (sequence, channel, height, width) order.
    std::vector<ShapeRange> ranges(4);
    ranges[0] = inputShape.sequenceRange();
    ranges[1] = inputShape.channelRange();
    ranges[2] = inputShape.heightRange();
    ranges[3] = inputShape.widthRange();

    // The params are only read, so bind by reference instead of copying the message.
    const Specification::PermuteLayerParams& permute = specLayer.permute();

    // axis(i) is not bounds-checked by the accessor; reject short axis lists explicitly.
    if (permute.axis_size() < 4) {
        throw std::runtime_error("Permute layer must specify 4 axes in shapePermuteLayer.");
    }

    const size_t axis0 = static_cast<size_t>(permute.axis(0));
    const size_t axis1 = static_cast<size_t>(permute.axis(1));
    const size_t axis2 = static_cast<size_t>(permute.axis(2));
    const size_t axis3 = static_cast<size_t>(permute.axis(3));

    // Check that indices into "ranges" and "outranges" are not out of bounds.
    if (axis0 > 3 || axis1 > 3 || axis2 > 3 || axis3 > 3) {
        throw std::runtime_error("Ranges axis index is out of bounds in shapePermuteLayer.");
    }

    ShapeConstraint& outputShape = blobShapes[specLayer.output(0)];
    outputShape.setName(specLayer.output(0));

#if COREML_VALIDATOR_VERBOSE
    std::cout << "Permute layer " << specLayer.name() << " input shapes (before): " << std::endl;
    std::cout << inputShape;
    std::cout << "Permute layer " << specLayer.name() << " output shapes (before): " << std::endl;
    std::cout << outputShape;
#endif

    // Forward pass: output dimension i takes the input range picked by axis i.
    outputShape.updateSequenceRange(ranges[axis0]);
    outputShape.updateChannelRange(ranges[axis1]);
    outputShape.updateHeightRange(ranges[axis2]);
    outputShape.updateWidthRange(ranges[axis3]);

    // Reverse pass: scatter the (possibly tightened) output ranges back to the input
    // slots they came from. If the axes are not a true permutation, any slot that is
    // never written keeps a default-constructed ShapeRange.
    std::vector<ShapeRange> outranges(4);
    outranges[axis0] = outputShape.sequenceRange();
    outranges[axis1] = outputShape.channelRange();
    outranges[axis2] = outputShape.heightRange();
    outranges[axis3] = outputShape.widthRange();

    // Need the constraints to go both ways here
    inputShape.updateSequenceRange(outranges[0]);
    inputShape.updateChannelRange(outranges[1]);
    inputShape.updateHeightRange(outranges[2]);
    inputShape.updateWidthRange(outranges[3]);

    // NOTE(review): batch is only propagated output -> input; the input's batch range
    // is never applied to the output here, unlike most other layers -- confirm intended.
    inputShape.updateBatchRange(outputShape.batchRange());

#if COREML_VALIDATOR_VERBOSE
    std::cout << "Permute layer " << specLayer.name() << " input shapes (after): " << std::endl;
    std::cout << inputShape;
    std::cout << "Permute layer " << specLayer.name() << " output shapes (after): " << std::endl;
    std::cout << outputShape;
#endif

}

// Derives the output shape of a Concat layer.
//
// Depending on concat().sequenceconcat(), inputs are joined along either the sequence
// or the channel axis: the ranges of that axis are summed over all inputs and applied
// to the output at the end. The other axes are applied from every input to the output,
// and for inputs 1..n-1 the output's ranges are also written back to the input.
// input(0) only feeds the output; nothing is written back to it.
void NeuralNetworkShaper::shapeConcatLayer(const Specification::NeuralNetworkLayer& specLayer) {

    ShapeConstraint& firstInput = blobShapes[specLayer.input(0)];
    ShapeConstraint& joined = blobShapes[specLayer.output(0)];
    joined.setName(specLayer.output(0));

#if COREML_VALIDATOR_VERBOSE
    std::cout << "Concat layer " << specLayer.name() << " input shapes (before): " << std::endl;
    for (int idx = 0; idx < specLayer.input_size(); ++idx) {
        std::cout << blobShapes[specLayer.input(idx)];
    }
    std::cout << "Concat layer " << specLayer.name() << " output shapes (before): " << std::endl;
    std::cout << joined;
#endif

    const bool alongSequence = specLayer.concat().sequenceconcat();

    // Running sum of the concatenated axis, seeded from the first input.
    ShapeRange concatExtent;
    if (!alongSequence) {
        concatExtent = firstInput.channelRange();
        joined.updateSequenceRange(firstInput.sequenceRange());
    }
    else {
        concatExtent = firstInput.sequenceRange();
        joined.updateChannelRange(firstInput.channelRange());
    }
    joined.updateBatchRange(firstInput.batchRange());
    joined.updateHeightRange(firstInput.heightRange());
    joined.updateWidthRange(firstInput.widthRange());

    const int inputCount = specLayer.input_size();
    for (int idx = 1; idx < inputCount; ++idx) {
        ShapeConstraint& current = blobShapes[specLayer.input(idx)];

        if (!alongSequence) {
            concatExtent = concatExtent + current.channelRange();
            joined.updateSequenceRange(current.sequenceRange());
            current.updateSequenceRange(joined.sequenceRange());
        }
        else {
            concatExtent = concatExtent + current.sequenceRange();
            joined.updateChannelRange(current.channelRange());
            current.updateChannelRange(joined.channelRange());
        }

        // Axes that are never concatenated must agree between this input and the output.
        joined.updateBatchRange(current.batchRange());
        current.updateBatchRange(joined.batchRange());

        joined.updateHeightRange(current.heightRange());
        current.updateHeightRange(joined.heightRange());

        joined.updateWidthRange(current.widthRange());
        current.updateWidthRange(joined.widthRange());
    }

    // Only now is the summed extent applied to the concatenated axis of the output.
    if (!alongSequence) {
        joined.updateChannelRange(concatExtent);
    }
    else {
        joined.updateSequenceRange(concatExtent);
    }

#if COREML_VALIDATOR_VERBOSE
    std::cout << "Concat layer " << specLayer.name() << " input shapes (after): " << std::endl;
    for (int idx = 0; idx < specLayer.input_size(); ++idx) {
        std::cout << blobShapes[specLayer.input(idx)];
    }
    std::cout << "Concat layer " << specLayer.name() << " output shapes (after): " << std::endl;
    std::cout << joined;
#endif

}

// Propagates shape constraints through a Split layer.
//
// Every output receives the input's sequence, batch, height and width ranges, and the
// input's channel range divided by split().noutputs(). For the reverse pass only
// output(0) is consulted: its ranges are written back to the input, with the channel
// range multiplied by noutputs.
void NeuralNetworkShaper::shapeSplitLayer(const Specification::NeuralNetworkLayer& specLayer) {

    ShapeConstraint& source = blobShapes[specLayer.input(0)];

#if COREML_VALIDATOR_VERBOSE
    std::cout << "Split layer " << specLayer.name() << " input shapes (before): " << std::endl;
    std::cout << source;
    std::cout << "Split layer " << specLayer.name() << " output shapes (before): " << std::endl;
    for (int idx = 0; idx < specLayer.output_size(); ++idx) {
        ShapeConstraint& piece = blobShapes[specLayer.output(idx)];
        piece.setName(specLayer.output(idx));
        std::cout << piece;
    }
#endif

    // Number of pieces declared in the layer params (not the number of output blobs).
    int pieces = static_cast<int>(specLayer.split().noutputs());

    // Forward pass over every output blob actually listed on the layer.
    const int outputCount = specLayer.output_size();
    for (int idx = 0; idx < outputCount; ++idx) {
        ShapeConstraint& piece = blobShapes[specLayer.output(idx)];
        piece.setName(specLayer.output(idx));

        piece.updateSequenceRange(source.sequenceRange());
        piece.updateBatchRange(source.batchRange());
        piece.updateChannelRange(source.channelRange() / pieces);
        piece.updateHeightRange(source.heightRange());
        piece.updateWidthRange(source.widthRange());
    }

    // Reverse pass, driven by the first output only.
    ShapeConstraint& firstPiece = blobShapes[specLayer.output(0)];
    source.updateSequenceRange(firstPiece.sequenceRange());
    source.updateBatchRange(firstPiece.batchRange());
    source.updateChannelRange(firstPiece.channelRange() * static_cast<size_t>(pieces));
    source.updateHeightRange(firstPiece.heightRange());
    source.updateWidthRange(firstPiece.widthRange());

#if COREML_VALIDATOR_VERBOSE
    std::cout << "Split layer " << specLayer.name() << " input shapes (after): " << std::endl;
    std::cout << source;
    std::cout << "Split layer " << specLayer.name() << " output shapes (after): " << std::endl;
    for (int idx = 0; idx < specLayer.output_size(); ++idx) {
        ShapeConstraint& piece = blobShapes[specLayer.output(idx)];
        piece.setName(specLayer.output(idx));
        std::cout << piece;
    }
#endif

}

// Propagates shape constraints through a Sequence Repeat layer.
//
// The output sequence range is updated with the input's sequence range multiplied by
// sequencerepeat().nrepetitions(). Batch, channel, height and width ranges are applied
// from input(0) to output(0) and then written back from the output to the input.
// The sequence range is not propagated back to the input.
void NeuralNetworkShaper::shapeSequenceRepeatLayer(const Specification::NeuralNetworkLayer& specLayer) {

    ShapeConstraint& inputShape = blobShapes[specLayer.input(0)];
    ShapeConstraint& outputShape = blobShapes[specLayer.output(0)];
    outputShape.setName(specLayer.output(0));

    // This dump happens before any range is updated, so it is labelled "(before)"
    // (it was previously mislabelled "(after)", making verbose logs ambiguous).
#if COREML_VALIDATOR_VERBOSE
    std::cout << "Sequence Repeat layer " << specLayer.name() << " input shapes (before): " << std::endl;
    std::cout << inputShape;
    std::cout << "Sequence Repeat layer " << specLayer.name() << " output shapes (before): " << std::endl;
    std::cout << outputShape;
#endif

    const size_t nout = static_cast<size_t>(specLayer.sequencerepeat().nrepetitions());
    outputShape.updateSequenceRange(inputShape.sequenceRange() * nout);
    outputShape.updateBatchRange(inputShape.batchRange());
    outputShape.updateChannelRange(inputShape.channelRange());
    outputShape.updateHeightRange(inputShape.heightRange());
    outputShape.updateWidthRange(inputShape.widthRange());

    // Pass constraints back up as well
    inputShape.updateBatchRange(outputShape.batchRange());
    inputShape.updateChannelRange(outputShape.channelRange());
    inputShape.updateHeightRange(outputShape.heightRange());
    inputShape.updateWidthRange(outputShape.widthRange());

#if COREML_VALIDATOR_VERBOSE
    std::cout << "Sequence Repeat layer " << specLayer.name() << " input shapes (after): " << std::endl;
    std::cout << inputShape;
    std::cout << "Sequence Repeat layer " << specLayer.name() << " output shapes (after): " << std::endl;
    std::cout << outputShape;
#endif

}

// Propagates shape constraints through a Reorganize Data layer.
//
// Sequence and batch ranges are exchanged between input(0) and output(0). For
// SPACE_TO_DEPTH the output channel range is the input's times blockSize^2 and the
// output height/width are the input's divided by blockSize; any other mode applies
// the inverse scaling. In both cases the resulting output ranges are mapped back
// onto the input with the opposite operation.
void NeuralNetworkShaper::shapeReorganizeDataLayer(const Specification::NeuralNetworkLayer& specLayer) {

    ShapeConstraint& source = blobShapes[specLayer.input(0)];
    ShapeConstraint& result = blobShapes[specLayer.output(0)];
    result.setName(specLayer.output(0));

#if COREML_VALIDATOR_VERBOSE
    std::cout << "Reorganize Data layer " << specLayer.name() << " input shapes (before): " << std::endl;
    std::cout << source;
    std::cout << "Reorganize Data layer " << specLayer.name() << " output shapes (before): " << std::endl;
    std::cout << result;
#endif

    // Axes the layer does not rearrange: forward, then reverse.
    result.updateSequenceRange(source.sequenceRange());
    result.updateBatchRange(source.batchRange());

    source.updateSequenceRange(result.sequenceRange());
    source.updateBatchRange(result.batchRange());

    const Specification::ReorganizeDataLayerParams& params = specLayer.reorganizedata();
    const size_t block = static_cast<size_t>(params.blocksize());
    const size_t blockArea = block * block;

    const bool spaceToDepth = (params.mode() == Specification::ReorganizeDataLayerParams::SPACE_TO_DEPTH);

    if (!spaceToDepth) {
        // Channels shrink by block^2 while height and width grow by block.
        result.updateChannelRange(source.channelRange() / blockArea);
        result.updateHeightRange(source.heightRange() * block);
        result.updateWidthRange(source.widthRange() * block);

        source.updateChannelRange(result.channelRange() * blockArea);
        source.updateHeightRange(result.heightRange() / block);
        source.updateWidthRange(result.widthRange() / block);
    } else {
        // Channels grow by block^2 while height and width shrink by block.
        result.updateChannelRange(source.channelRange() * blockArea);
        result.updateHeightRange(source.heightRange() / block);
        result.updateWidthRange(source.widthRange() / block);

        source.updateChannelRange(result.channelRange() / blockArea);
        source.updateHeightRange(result.heightRange() * block);
        source.updateWidthRange(result.widthRange() * block);
    }

#if COREML_VALIDATOR_VERBOSE
    std::cout << "Reorganize Data layer " << specLayer.name() << " input shapes (after): " << std::endl;
    std::cout << source;
    std::cout << "Reorganize Data layer " << specLayer.name() << " output shapes (after): " << std::endl;
    std::cout << result;
#endif

}

// Propagates shape constraints through a Slice layer.
//
// Sequence and batch ranges are exchanged between input(0) and output(0). Along the
// sliced axis (channel, height or width) the output size is either a fixed number
// computed from startindex/endindex/stride alone, or a range derived from the input's
// range on that axis, depending on the signs of the two indices. The two axes that are
// not sliced are exchanged between input and output unchanged, and the input's sliced
// axis gets a lower bound.
//
// Throws std::runtime_error if the axis is not CHANNEL_AXIS, HEIGHT_AXIS or WIDTH_AXIS.
void NeuralNetworkShaper::shapeSliceLayer(const Specification::NeuralNetworkLayer& specLayer) {

    //get the input shape
    ShapeConstraint& inputShape = blobShapes[specLayer.input(0)];
    ShapeConstraint& outputShape = blobShapes[specLayer.output(0)];
    outputShape.setName(specLayer.output(0));

#if COREML_VALIDATOR_VERBOSE
    std::cout << "Slice layer " << specLayer.name() << " input shapes (before): " << std::endl;
    std::cout << inputShape;
    std::cout << "Slice layer " << specLayer.name() << " output shapes (before): " << std::endl;
    std::cout << outputShape;
#endif

    // Slicing never touches sequence or batch: forward pass, then reverse pass.
    outputShape.updateSequenceRange(inputShape.sequenceRange());
    outputShape.updateBatchRange(inputShape.batchRange());

    inputShape.updateSequenceRange(outputShape.sequenceRange());
    inputShape.updateBatchRange(outputShape.batchRange());

    Specification::SliceLayerParams slice = specLayer.slice();
    int start = static_cast<int>(slice.startindex());
    int end = static_cast<int>(slice.endindex());
    int stride = static_cast<int>(slice.stride());
    auto axis = slice.axis();

    // outsize: fixed output length along the sliced axis (only meaningful when fixedSize).
    // inLowerBound: value later passed to lowerBound{Channel,Height,Width} on the input.
    int outsize = 0;
    int inLowerBound = 0;
    bool fixedSize = false;

    // NOTE(review): stride is used as a divisor below with no zero check in this function;
    // presumably the validator guarantees stride >= 1 -- confirm.
    if (start >= 0 && end > 0) {
        // Both indices count from the front, so the output length does not depend on the input size.
        fixedSize = true;
        outsize = (end - 1 - start) / stride + 1;
        inLowerBound = end;
    }
    else if (start < 0 && end <= 0) {
        // Both indices are non-positive (presumably counted from the back); again the
        // output length is independent of the input size.
        fixedSize = true;
        inLowerBound = -1*start;
        outsize = (-start - 1 + end) / stride + 1;
    }

    // TODO: add check that this size is possible from the input size

    if (fixedSize) {
        // Pin the sliced axis of the output to outsize, exchange the other two axes in
        // both directions, and lower-bound the sliced axis of the input.
        switch (axis) {
            case Specification::SliceLayerParams::CHANNEL_AXIS:
                outputShape.setChannel(static_cast<size_t>(outsize));
                outputShape.updateHeightRange(inputShape.heightRange());
                outputShape.updateWidthRange(inputShape.widthRange());

                inputShape.lowerBoundChannel(static_cast<size_t>(inLowerBound));
                inputShape.updateHeightRange(outputShape.heightRange());
                inputShape.updateWidthRange(outputShape.widthRange());

                break;
            case Specification::SliceLayerParams::HEIGHT_AXIS:
                outputShape.updateChannelRange(inputShape.channelRange());
                outputShape.setHeight(static_cast<size_t>(outsize));
                outputShape.updateWidthRange(inputShape.widthRange());

                inputShape.updateChannelRange(outputShape.channelRange());
                inputShape.lowerBoundHeight(static_cast<size_t>(inLowerBound));
                inputShape.updateWidthRange(outputShape.widthRange());

                break;
            case Specification::SliceLayerParams::WIDTH_AXIS:
                outputShape.updateChannelRange(inputShape.channelRange());
                outputShape.updateHeightRange(inputShape.heightRange());
                outputShape.setWidth(static_cast<size_t>(outsize));

                inputShape.updateChannelRange(outputShape.channelRange());
                inputShape.updateHeightRange(outputShape.heightRange());
                inputShape.lowerBoundWidth(static_cast<size_t>(inLowerBound));

                break;
            default:
                throw std::runtime_error("Slice layer axis incorrect -- should be caught in validator.");
                break;
        }
    }
    else {
        // Mixed case: the output length depends on the input's range along the sliced axis.
        ShapeRange size;
        switch (axis) {
            case Specification::SliceLayerParams::CHANNEL_AXIS:
                size = inputShape.channelRange();
                break;
            case Specification::SliceLayerParams::HEIGHT_AXIS:
                size = inputShape.heightRange();
                break;
            case Specification::SliceLayerParams::WIDTH_AXIS:
                size = inputShape.widthRange();
                break;
            default:
                throw std::runtime_error("Slice layer axis incorrect -- should be caught in validator.");
                break;
        }

        // Replace whichever index is non-positive/negative by its magnitude so that both
        // values below are non-negative.
        if (end <= 0) {
            // Here start >= 0 and end <= 0.
            end = (-1*end);
        }
        else { // here start < 0 and end > 0; the same-sign cases were handled above
            start = (-1*start);
        }

        // Output range along the sliced axis, as a function of the input range.
        ShapeRange outrange = (size - (start + 1 + end)) / stride + 1;

        inLowerBound = start + 1 + end;

        // Same exchange as in the fixed-size branch, but the sliced axis of the output
        // receives a range instead of a single value.
        switch (axis) {
            case Specification::SliceLayerParams::CHANNEL_AXIS: {
                outputShape.updateChannelRange(outrange);
                outputShape.updateHeightRange(inputShape.heightRange());
                outputShape.updateWidthRange(inputShape.widthRange());

                inputShape.lowerBoundChannel(static_cast<size_t>(inLowerBound));
                inputShape.updateHeightRange(outputShape.heightRange());
                inputShape.updateWidthRange(outputShape.widthRange());

                break;
            }
            case Specification::SliceLayerParams::HEIGHT_AXIS: {
                outputShape.updateChannelRange(inputShape.channelRange());
                outputShape.updateHeightRange(outrange);
                outputShape.updateWidthRange(inputShape.widthRange());

                inputShape.updateChannelRange(outputShape.channelRange());
                inputShape.lowerBoundHeight(static_cast<size_t>(inLowerBound));
                inputShape.updateWidthRange(outputShape.widthRange());

                break;
            }
            case Specification::SliceLayerParams::WIDTH_AXIS: {
                outputShape.updateChannelRange(inputShape.channelRange());
                outputShape.updateHeightRange(inputShape.heightRange());
                outputShape.updateWidthRange(outrange);

                inputShape.updateChannelRange(outputShape.channelRange());
                inputShape.updateHeightRange(outputShape.heightRange());
                inputShape.lowerBoundWidth(static_cast<size_t>(inLowerBound));

                break;
            }
            default: {
                throw std::runtime_error("Slice layer axis incorrect -- should be caught in validator.");
                break;
            }
        }

    } // signs don't match

#if COREML_VALIDATOR_VERBOSE
    std::cout << "Slice layer " << specLayer.name() << " input shapes (after): " << std::endl;
    std::cout << inputShape;
    std::cout << "Slice layer " << specLayer.name() << " output shapes (after): " << std::endl;
    std::cout << outputShape;
#endif

}

// Fixes the blob shapes of a Simple Recurrent layer.
//
// input(0) is pinned to (inputvectorsize, 1, 1) and output(0) to (outputvectorsize, 1, 1)
// in (channel, height, width). The output sequence range follows the input when
// sequenceoutput() is set and is pinned to 1 otherwise. The state blobs input(1) and
// output(1) are pinned to sequence 1 and (outputvectorsize, 1, 1).
// Both state blobs are accessed unconditionally.
void NeuralNetworkShaper::shapeSimpleRecurrentLayer(const Specification::NeuralNetworkLayer& specLayer) {

    ShapeConstraint& dataIn = blobShapes[specLayer.input(0)];
    ShapeConstraint& dataOut = blobShapes[specLayer.output(0)];
    dataOut.setName(specLayer.output(0));

    ShapeConstraint& stateIn = blobShapes[specLayer.input(1)];
    ShapeConstraint& stateOut = blobShapes[specLayer.output(1)];
    stateOut.setName(specLayer.output(1));

#if COREML_VALIDATOR_VERBOSE
    std::cout << "Simple Recurrent layer " << specLayer.name() << " input shapes (before): " << std::endl;
    for (int idx = 0; idx < specLayer.input_size(); ++idx) {
        std::cout << blobShapes[specLayer.input(idx)];
    }
    std::cout << "Simple Recurrent layer " << specLayer.name() << " output shapes (before): " << std::endl;
    for (int idx = 0; idx < specLayer.output_size(); ++idx) {
        std::cout << blobShapes[specLayer.output(idx)];
    }
#endif

    const Specification::SimpleRecurrentLayerParams& params = specLayer.simplerecurrent();

    const size_t inputWidth = static_cast<size_t>(params.inputvectorsize());
    const size_t hiddenWidth = static_cast<size_t>(params.outputvectorsize());

    // Pins a blob to a vector of the given length: (channels, 1, 1).
    auto pinVector = [](ShapeConstraint& shape, size_t channels) {
        shape.setChannel(channels);
        shape.setHeight(1);
        shape.setWidth(1);
    };
    // Pins a state blob: a single step holding a vector of the hidden width.
    auto pinState = [&](ShapeConstraint& shape) {
        shape.setSequence(1);
        pinVector(shape, hiddenWidth);
    };

    pinVector(dataIn, inputWidth);
    pinVector(dataOut, hiddenWidth);

    if (!params.sequenceoutput()) {
        dataOut.setSequence(1);
    }
    else {
        dataOut.updateSequenceRange(dataIn.sequenceRange());
    }

    pinState(stateIn);
    pinState(stateOut);

#if COREML_VALIDATOR_VERBOSE
    std::cout << "Simple Recurrent layer " << specLayer.name() << " input shapes (after): " << std::endl;
    for (int idx = 0; idx < specLayer.input_size(); ++idx) {
        std::cout << blobShapes[specLayer.input(idx)];
    }
    std::cout << "Simple Recurrent layer " << specLayer.name() << " output shapes (after): " << std::endl;
    for (int idx = 0; idx < specLayer.output_size(); ++idx) {
        std::cout << blobShapes[specLayer.output(idx)];
    }
#endif

}

// Fixes the blob shapes of a GRU layer.
//
// input(0) is pinned to (inputvectorsize, 1, 1) and output(0) to (outputvectorsize, 1, 1)
// in (channel, height, width). The output sequence range follows the input when
// sequenceoutput() is set and is pinned to 1 otherwise. If the layer lists at least two
// inputs, the state blobs input(1) and output(1) are pinned to sequence 1 and
// (outputvectorsize, 1, 1); otherwise the function returns early, which also skips the
// verbose "(after)" dump.
void NeuralNetworkShaper::shapeGRULayer(const Specification::NeuralNetworkLayer& specLayer) {

    ShapeConstraint& dataIn = blobShapes[specLayer.input(0)];
    ShapeConstraint& dataOut = blobShapes[specLayer.output(0)];
    dataOut.setName(specLayer.output(0));

#if COREML_VALIDATOR_VERBOSE
    std::cout << "GRU layer " << specLayer.name() << " input shapes (before): " << std::endl;
    for (int idx = 0; idx < specLayer.input_size(); ++idx) {
        std::cout << blobShapes[specLayer.input(idx)];
    }
    std::cout << "GRU layer " << specLayer.name() << " output shapes (before): " << std::endl;
    for (int idx = 0; idx < specLayer.output_size(); ++idx) {
        std::cout << blobShapes[specLayer.output(idx)];
    }
#endif

    const Specification::GRULayerParams& params = specLayer.gru();

    const size_t inputWidth = static_cast<size_t>(params.inputvectorsize());
    const size_t hiddenWidth = static_cast<size_t>(params.outputvectorsize());

    dataIn.setChannel(inputWidth);
    dataIn.setHeight(1);
    dataIn.setWidth(1);

    dataOut.setChannel(hiddenWidth);
    dataOut.setHeight(1);
    dataOut.setWidth(1);
    if (!params.sequenceoutput()) {
        dataOut.setSequence(1);
    }
    else {
        dataOut.updateSequenceRange(dataIn.sequenceRange());
    }

    // Without a second input there are no state blobs to constrain.
    const bool hasStateBlobs = !(specLayer.input_size() < 2);
    if (!hasStateBlobs) {
        return;
    }

    ShapeConstraint& stateIn = blobShapes[specLayer.input(1)];
    ShapeConstraint& stateOut = blobShapes[specLayer.output(1)];
    stateOut.setName(specLayer.output(1));

    // Incoming and outgoing state share one shape: a single step of the hidden width.
    ShapeConstraint* const stateBlobs[] = { &stateIn, &stateOut };
    for (ShapeConstraint* state : stateBlobs) {
        state->setSequence(1);
        state->setChannel(hiddenWidth);
        state->setHeight(1);
        state->setWidth(1);
    }

#if COREML_VALIDATOR_VERBOSE
    std::cout << "GRU layer " << specLayer.name() << " input shapes (after): " << std::endl;
    for (int idx = 0; idx < specLayer.input_size(); ++idx) {
        std::cout << blobShapes[specLayer.input(idx)];
    }
    std::cout << "GRU layer " << specLayer.name() << " output shapes (after): " << std::endl;
    for (int idx = 0; idx < specLayer.output_size(); ++idx) {
        std::cout << blobShapes[specLayer.output(idx)];
    }
#endif

}

// Fixes the blob shapes of a Unidirectional LSTM layer.
//
// input(0) is pinned to (inputvectorsize, 1, 1) and output(0) to (outputvectorsize, 1, 1)
// in (channel, height, width). The output sequence range follows the input when
// params().sequenceoutput() is set and is pinned to 1 otherwise. If the layer lists at
// least three inputs, blobs 1 and 2 on both the input and output side are pinned to
// sequence 1 and (outputvectorsize, 1, 1); otherwise the function returns early, which
// also skips the verbose "(after)" dump.
void NeuralNetworkShaper::shapeUnidirectionalLSTMLayer(const Specification::NeuralNetworkLayer& specLayer) {

    ShapeConstraint& dataIn = blobShapes[specLayer.input(0)];
    ShapeConstraint& dataOut = blobShapes[specLayer.output(0)];
    dataOut.setName(specLayer.output(0));

#if COREML_VALIDATOR_VERBOSE
    std::cout << "Unidirectional LSTM layer " << specLayer.name() << " input shapes (before): " << std::endl;
    for (int idx = 0; idx < specLayer.input_size(); ++idx) {
        std::cout << blobShapes[specLayer.input(idx)];
    }
    std::cout << "Unidirectional LSTM layer " << specLayer.name() << " output shapes (before): " << std::endl;
    for (int idx = 0; idx < specLayer.output_size(); ++idx) {
        std::cout << blobShapes[specLayer.output(idx)];
    }
#endif

    const Specification::UniDirectionalLSTMLayerParams& lstm = specLayer.unidirectionallstm();

    const size_t inputWidth = static_cast<size_t>(lstm.inputvectorsize());
    const size_t hiddenWidth = static_cast<size_t>(lstm.outputvectorsize());

    dataIn.setChannel(inputWidth);
    dataIn.setHeight(1);
    dataIn.setWidth(1);

    dataOut.setChannel(hiddenWidth);
    dataOut.setHeight(1);
    dataOut.setWidth(1);
    if (!lstm.params().sequenceoutput()) {
        dataOut.setSequence(1);
    }
    else {
        dataOut.updateSequenceRange(dataIn.sequenceRange());
    }

    // Fewer than three inputs means the two extra blob pairs are not present.
    const bool hasExtraBlobs = !(specLayer.input_size() < 3);
    if (!hasExtraBlobs) {
        return;
    }

    ShapeConstraint& firstIn = blobShapes[specLayer.input(1)];
    ShapeConstraint& firstOut = blobShapes[specLayer.output(1)];
    firstOut.setName(specLayer.output(1));

    ShapeConstraint& secondIn = blobShapes[specLayer.input(2)];
    ShapeConstraint& secondOut = blobShapes[specLayer.output(2)];
    secondOut.setName(specLayer.output(2));

    // All four extra blobs get the same shape: one step of the hidden width.
    // Inputs are constrained before outputs.
    ShapeConstraint* const extraBlobs[] = { &firstIn, &secondIn, &firstOut, &secondOut };
    for (ShapeConstraint* blob : extraBlobs) {
        blob->setSequence(1);
        blob->setChannel(hiddenWidth);
        blob->setHeight(1);
        blob->setWidth(1);
    }

#if COREML_VALIDATOR_VERBOSE
    std::cout << "Unidirectional LSTM layer " << specLayer.name() << " input shapes (after): " << std::endl;
    for (int idx = 0; idx < specLayer.input_size(); ++idx) {
        std::cout << blobShapes[specLayer.input(idx)];
    }
    std::cout << "Unidirectional LSTM layer " << specLayer.name() << " output shapes (after): " << std::endl;
    for (int idx = 0; idx < specLayer.output_size(); ++idx) {
        std::cout << blobShapes[specLayer.output(idx)];
    }
#endif

}

// Shape inference for a bidirectional LSTM layer.
//
// Blob layout read from the layer spec:
//   input(0)  / output(0)   : data blobs. Input channels are fixed to the
//                             layer's input vector size; output channels to
//                             twice the output vector size.
//   input(1..4) / output(1..4) : optional state / hidden blobs, first for one
//                             direction and then for the reverse direction.
//                             Each is pinned to [Seq = 1, C = outSize, H = 1, W = 1].
//
// The optional blobs are only constrained when the layer lists at least five
// inputs AND at least five outputs; otherwise only the data blobs are shaped.
void NeuralNetworkShaper::shapeBidirectionalLSTMLayer(const Specification::NeuralNetworkLayer& specLayer) {

    // get the data input / output shapes
    ShapeConstraint& inputShape = blobShapes[specLayer.input(0)];
    ShapeConstraint& outputShape = blobShapes[specLayer.output(0)];
    outputShape.setName(specLayer.output(0));

#if COREML_VALIDATOR_VERBOSE
    std::cout << "Bidirectional LSTM layer " << specLayer.name() << " input shapes (before): " << std::endl;
    for (int i = 0; i < specLayer.input_size(); i++) {
        std::cout << blobShapes[specLayer.input(i)];
    }
    std::cout << "Bidirectional LSTM layer " << specLayer.name() << " output shapes (before): " << std::endl;
    for (int i = 0; i < specLayer.output_size(); i++) {
        std::cout << blobShapes[specLayer.output(i)];
    }
#endif

    const Specification::BiDirectionalLSTMLayerParams& recurrent = specLayer.bidirectionallstm();

    const size_t inSize = static_cast<size_t>(recurrent.inputvectorsize());
    const size_t outSize = static_cast<size_t>(recurrent.outputvectorsize());

    // This is the current maximum sequence length for bidirectional models.
    inputShape.upperBoundSequence(10000);

    inputShape.setChannel(inSize);
    inputShape.setHeight(1);
    inputShape.setWidth(1);

    // Both directions contribute outSize channels to the data output.
    outputShape.setChannel(2*outSize);
    outputShape.setHeight(1);
    outputShape.setWidth(1);
    if (recurrent.params().sequenceoutput()) {
        outputShape.updateSequenceRange(inputShape.sequenceRange());
    }
    else {
        outputShape.setSequence(1);
    }

    // The state / hidden blobs are optional. Both lists are checked because
    // output(1..4) is indexed below just like input(1..4); indexing past the
    // end of a repeated field is not a checked operation.
    if (specLayer.input_size() < 5 || specLayer.output_size() < 5) {
        return;
    }

    ShapeConstraint& stateInShape = blobShapes[specLayer.input(1)];
    ShapeConstraint& stateOutShape = blobShapes[specLayer.output(1)];
    stateOutShape.setName(specLayer.output(1));

    ShapeConstraint& hiddenInShape = blobShapes[specLayer.input(2)];
    ShapeConstraint& hiddenOutShape = blobShapes[specLayer.output(2)];
    hiddenOutShape.setName(specLayer.output(2));

    ShapeConstraint& stateInShapeRev = blobShapes[specLayer.input(3)];
    ShapeConstraint& stateOutShapeRev = blobShapes[specLayer.output(3)];
    stateOutShapeRev.setName(specLayer.output(3));

    ShapeConstraint& hiddenInShapeRev = blobShapes[specLayer.input(4)];
    ShapeConstraint& hiddenOutShapeRev = blobShapes[specLayer.output(4)];
    hiddenOutShapeRev.setName(specLayer.output(4));

    // Every state / hidden blob gets the same fixed shape.
    const auto pinStateBlob = [outSize](ShapeConstraint& blob) {
        blob.setSequence(1);
        blob.setChannel(outSize);
        blob.setHeight(1);
        blob.setWidth(1);
    };

    // Applied in the same order as before: inputs then outputs, first for the
    // blobs at indices 1-2 and then for the reverse-direction blobs at 3-4.
    pinStateBlob(stateInShape);
    pinStateBlob(hiddenInShape);
    pinStateBlob(stateOutShape);
    pinStateBlob(hiddenOutShape);

    pinStateBlob(stateInShapeRev);
    pinStateBlob(hiddenInShapeRev);
    pinStateBlob(stateOutShapeRev);
    pinStateBlob(hiddenOutShapeRev);

#if COREML_VALIDATOR_VERBOSE
    std::cout << "Bidirectional LSTM layer " << specLayer.name() << " input shapes (after): " << std::endl;
    for (int i = 0; i < specLayer.input_size(); i++) {
        std::cout << blobShapes[specLayer.input(i)];
    }
    std::cout << "Bidirectional LSTM layer " << specLayer.name() << " output shapes (after): " << std::endl;
    for (int i = 0; i < specLayer.output_size(); i++) {
        std::cout << blobShapes[specLayer.output(i)];
    }
#endif

}

// Custom layers have no implementation visible here, so nothing can be said
// about their input or output dimensions. Each output blob is only looked up
// in blobShapes (via operator[]) and given its name; no range is constrained.
void NeuralNetworkShaper::shapeCustomLayer(const Specification::NeuralNetworkLayer& specLayer) {

    for (const auto& outputName : specLayer.output()) {
        // name the blob and otherwise leave it unconstrained
        blobShapes[outputName].setName(outputName);
    }

}

// Shape inference for a bilinear resize layer.
//
// Input shape:  [Seq, B, C, H_in,  W_in]
// Output shape: [Seq, B, C, H_out, W_out]
//
// Sequence, batch and channel ranges are tied between input and output by
// intersecting them in both directions. The output height / width are fixed
// to the layer's target size; each defaults to 1 when the target size does
// not have exactly two entries or when the corresponding entry is 0.
void NeuralNetworkShaper::shapeResizeBilinearLayer(const Specification::NeuralNetworkLayer& specLayer) {

    ShapeConstraint& inputShape = blobShapes[specLayer.input(0)];
    ShapeConstraint& outputShape = blobShapes[specLayer.output(0)];
    outputShape.setName(specLayer.output(0));

#if COREML_VALIDATOR_VERBOSE
    std::cout << "Resize Bilinear layer " << specLayer.name() << " input shapes (before): " << std::endl;
    std::cout << inputShape;
    std::cout << "Resize Bilinear layer " << specLayer.name() << " output shapes (before): " << std::endl;
    std::cout << outputShape;
#endif

    // For forward pass: update the output shape ranges
    outputShape.updateSequenceRange(outputShape.sequenceRange().intersect(inputShape.sequenceRange()));
    outputShape.updateBatchRange(outputShape.batchRange().intersect(inputShape.batchRange()));
    outputShape.updateChannelRange(outputShape.channelRange().intersect(inputShape.channelRange()));

    // Bind by reference: the parameters are only read, so copying the whole
    // message is unnecessary.
    const Specification::ResizeBilinearLayerParams& resize = specLayer.resizebilinear();

    size_t target_h = 1;
    size_t target_w = 1;

    if (resize.targetsize_size() == 2) {
        // a 0 entry is treated as 1
        target_h = (resize.targetsize(0) == 0) ? 1 : static_cast<size_t>(resize.targetsize(0)); //height
        target_w = (resize.targetsize(1) == 0) ? 1 : static_cast<size_t>(resize.targetsize(1)); //width
    }

    outputShape.setHeight(target_h);
    outputShape.setWidth(target_w);

    // For backward pass: update the input shape ranges
    inputShape.updateSequenceRange(inputShape.sequenceRange().intersect(outputShape.sequenceRange()));
    inputShape.updateBatchRange(inputShape.batchRange().intersect(outputShape.batchRange()));
    inputShape.updateChannelRange(inputShape.channelRange().intersect(outputShape.channelRange()));

#if COREML_VALIDATOR_VERBOSE
    std::cout << "Resize Bilinear layer " << specLayer.name() << " input shapes (after): " << std::endl;
    std::cout << inputShape;
    std::cout << "Resize Bilinear layer " << specLayer.name() << " output shapes (after): " << std::endl;
    std::cout << outputShape;
#endif

}

// Shape inference for a crop-and-resize layer.
//
// Input shapes: image = [1, B, C, H_in, W_in], roi = [N, 1, 4/5, 1, 1]
// Output shape: [N, B, C, H_out, W_out] if roi shape = [N, 1, 4, 1, 1]
// Output shape: [N, 1, C, H_out, W_out] if roi shape = [N, 1, 5, 1, 1]
//
// input(0) is the image blob and input(1) the ROI blob. The output takes its
// sequence range from the ROI blob and its batch / channel ranges from the
// image blob; its height / width are fixed to the layer's target size (each
// defaults to 1 when the target size does not have exactly two entries or
// when the corresponding entry is 0).
void NeuralNetworkShaper::shapeCropResizeLayer(const Specification::NeuralNetworkLayer& specLayer) {

    ShapeConstraint& inputShape = blobShapes[specLayer.input(0)];
    ShapeConstraint& roiInputShape = blobShapes[specLayer.input(1)];
    ShapeConstraint& outputShape = blobShapes[specLayer.output(0)];
    outputShape.setName(specLayer.output(0));

#if COREML_VALIDATOR_VERBOSE
    std::cout << "Crop resize layer " << specLayer.name() << " input shapes (before): " << std::endl;
    std::cout << inputShape;
    std::cout << "Crop resize layer " << specLayer.name() << " output shapes (before): " << std::endl;
    std::cout << outputShape;
#endif

    // For forward pass
    outputShape.updateSequenceRange(outputShape.sequenceRange().intersect(roiInputShape.sequenceRange()));
    outputShape.updateBatchRange(outputShape.batchRange().intersect(inputShape.batchRange()));
    outputShape.updateChannelRange(outputShape.channelRange().intersect(inputShape.channelRange()));

    // Bind by reference: the parameters are only read, so copying the whole
    // message is unnecessary.
    const Specification::CropResizeLayerParams& crop_resize = specLayer.cropresize();

    size_t target_h = 1;
    size_t target_w = 1;

    if (crop_resize.targetsize_size() == 2) {
        // a 0 entry is treated as 1
        target_h = (crop_resize.targetsize(0) == 0) ? 1 : static_cast<size_t>(crop_resize.targetsize(0)); //height
        target_w = (crop_resize.targetsize(1) == 0) ? 1 : static_cast<size_t>(crop_resize.targetsize(1)); //width
    }

    outputShape.setHeight(target_h);
    outputShape.setWidth(target_w);

    // For Backward pass
    // image input: only the channel range is tied back to the output
    inputShape.updateChannelRange(inputShape.channelRange().intersect(outputShape.channelRange()));

    // roi input: [N, 1, 4 or 5, 1, 1], with N tied to the output sequence range
    roiInputShape.updateSequenceRange(roiInputShape.sequenceRange().intersect(outputShape.sequenceRange()));
    roiInputShape.setBatch(1);
    roiInputShape.updateChannelRange(ShapeRange(4, 5));
    roiInputShape.setWidth(1);
    roiInputShape.setHeight(1);

#if COREML_VALIDATOR_VERBOSE
    std::cout << "Crop resize layer " << specLayer.name() << " input shapes (after): " << std::endl;
    std::cout << inputShape;
    std::cout << "Crop resize layer " << specLayer.name() << " output shapes (after): " << std::endl;
    std::cout << outputShape;
#endif

}

// Dispatches a layer to the shape-inference routine for its layer type.
//
// Layer types that share a routine are grouped under one set of case labels.
// Throws std::runtime_error when the layer type is not set, or when the type
// has no shape-inference routine listed here.
void NeuralNetworkShaper::ProcessLayer(const Specification::NeuralNetworkLayer& layer) {

    typedef Specification::NeuralNetworkLayer Layer;

    switch (layer.layer_case()) {

        // ---- layers handled by shapeUnchanged ----
        case Layer::kActivation:
        case Layer::kBatchnorm:
        case Layer::kMvn:
        case Layer::kL2Normalize:
        case Layer::kSoftmax:
        case Layer::kLrn:
        case Layer::kUnary:
        case Layer::kScale:
        case Layer::kBias:
        case Layer::kMax:
        case Layer::kMin:
            shapeUnchanged(layer);
            break;

        // ---- layers handled by shapeBroadcastLayer ----
        case Layer::kAdd:
        case Layer::kMultiply:
        case Layer::kAverage:
            shapeBroadcastLayer(layer);
            break;

        // ---- layers with a dedicated routine ----
        case Layer::kConvolution:
            shapeConvolutionLayer(layer);
            break;
        case Layer::kPooling:
            shapePoolingLayer(layer);
            break;
        case Layer::kInnerProduct:
            shapeInnerProductLayer(layer);
            break;
        case Layer::kEmbedding:
            shapeEmbeddingLayer(layer);
            break;
        case Layer::kCrop:
            shapeCropLayer(layer);
            break;
        case Layer::kPadding:
            shapePaddingLayer(layer);
            break;
        case Layer::kUpsample:
            shapeUpsampleLayer(layer);
            break;
        case Layer::kDot:
            shapeDotLayer(layer);
            break;
        case Layer::kReduce:
            shapeReduceLayer(layer);
            break;
        case Layer::kLoadConstant:
            shapeLoadConstantLayer(layer);
            break;
        case Layer::kReshape:
            shapeReshapeLayer(layer);
            break;
        case Layer::kFlatten:
            shapeFlattenLayer(layer);
            break;
        case Layer::kPermute:
            shapePermuteLayer(layer);
            break;
        case Layer::kConcat:
            shapeConcatLayer(layer);
            break;
        case Layer::kSplit:
            shapeSplitLayer(layer);
            break;
        case Layer::kSequenceRepeat:
            shapeSequenceRepeatLayer(layer);
            break;
        case Layer::kReorganizeData:
            shapeReorganizeDataLayer(layer);
            break;
        case Layer::kSlice:
            shapeSliceLayer(layer);
            break;
        case Layer::kSimpleRecurrent:
            shapeSimpleRecurrentLayer(layer);
            break;
        case Layer::kGru:
            shapeGRULayer(layer);
            break;
        case Layer::kUniDirectionalLSTM:
            shapeUnidirectionalLSTMLayer(layer);
            break;
        case Layer::kBiDirectionalLSTM:
            shapeBidirectionalLSTMLayer(layer);
            break;
        case Layer::kCustom:
            shapeCustomLayer(layer);
            break;
        case Layer::kResizeBilinear:
            shapeResizeBilinearLayer(layer);
            break;
        case Layer::kCropResize:
            shapeCropResizeLayer(layer);
            break;

        // ---- errors ----
        case Layer::LAYER_NOT_SET:
            throw std::runtime_error("Layer type not found.");
        default:
            throw std::runtime_error("Shape inference not implemented for this layer type.");
    }
}

// Forward color propagation: every color recorded for any input blob of the
// layer is added to the color set of every output blob of the layer.
// Blob color sets are looked up with operator[] on blobColors.
void NeuralNetworkShaper::PassColorsDown(const Specification::NeuralNetworkLayer& layer) {
    for (int in = 0; in < layer.input_size(); ++in) {
        const std::set<int>& sourceColors = blobColors[layer.input(in)];
        for (int out = 0; out < layer.output_size(); ++out) {
            std::set<int>& targetColors = blobColors[layer.output(out)];
            // Inserted one at a time (not as an iterator range) so this stays
            // well-defined even if source and target are the same set.
            for (std::set<int>::const_iterator color = sourceColors.begin(); color != sourceColors.end(); ++color) {
                targetColors.insert(*color);
            }
        }
    }
}

// Backward color propagation: every color recorded for any output blob of the
// layer is added to the color set of every input blob of the layer.
// Blob color sets are looked up with operator[] on blobColors.
void NeuralNetworkShaper::PassColorsUp(const Specification::NeuralNetworkLayer& layer) {
    for (int out = 0; out < layer.output_size(); ++out) {
        const std::set<int>& sourceColors = blobColors[layer.output(out)];
        for (int in = 0; in < layer.input_size(); ++in) {
            std::set<int>& targetColors = blobColors[layer.input(in)];
            // Inserted one at a time (not as an iterator range) so this stays
            // well-defined even if source and target are the same set.
            for (std::set<int>::const_iterator color = sourceColors.begin(); color != sourceColors.end(); ++color) {
                targetColors.insert(*color);
            }
        }
    }
}

// Returns true when every blob tracked in blobColors holds exactly numColors
// colors, and false as soon as one blob with a different count is found.
bool NeuralNetworkShaper::AllShapesDone() {
    for (const auto& blobEntry : blobColors) {
        const int colorCount = static_cast<int>(blobEntry.second.size());
        if (colorCount != numColors) {
            return false;
        }
    }
    return true;
}

// Builds the shape constraints for every blob of a network.
//
//  interface                    - model description listing inputs and outputs
//  layers                       - the network layers, in order
//  useInputAndOutputConstraints - when true, the declared input feature types
//                                 seed the input constraints, and the declared
//                                 output feature types are applied to the
//                                 output blobs at the end
//
// Each model input is assigned its own "color" (its index). The layers are
// then processed repeatedly, forward and backward, with colors propagated
// along the way, until either every blob carries all colors or a full
// forward/backward round leaves the color assignment unchanged.
NeuralNetworkShaper::NeuralNetworkShaper(const Specification::ModelDescription& interface, const google::protobuf::RepeatedPtrField<Specification::NeuralNetworkLayer>& layers, bool useInputAndOutputConstraints)
    :
    numColors(interface.input().size())
{

    // ---- seed the input blobs ----
    const int inputCount = interface.input().size();
    for (int inputIdx = 0; inputIdx < inputCount; ++inputIdx) {
        const Specification::FeatureDescription& feature = interface.input(inputIdx);
        const std::string& blobName = feature.name();

        ShapeConstraint& seeded = blobShapes[blobName];
        seeded.setName(blobName);
        // Each input blob starts as its own color
        blobColors[blobName].insert(inputIdx);

        if (useInputAndOutputConstraints) {
            seeded.updateConstraint(feature.type());
        }
    }

    // ---- iterate forward / backward passes ----
    for (;;) {

        // snapshot used to detect a round that changed no colors
        const std::map<std::string, std::set<int> > colorsBeforeRound = blobColors;

#if COREML_VALIDATOR_VERBOSE
        std::cout << std::endl << "=====================" << std::endl << "Computing neural network shapes" << std::endl  << "=====================" << std::endl;
        std::cout << std::endl << "Starting forward computation of neural network shapes" << std::endl << std::endl;
#endif

        // forward pass
        for (const auto& layer : layers) {
            PassColorsDown(layer);
            ProcessLayer(layer);
        }

#if COREML_VALIDATOR_VERBOSE
        std::cout << std::endl << std::endl << "Starting backward computation of neural network shapes" << std::endl << std::endl;
#endif

        // backward pass
        for (int layerIdx = layers.size(); layerIdx-- > 0; ) {
            const Specification::NeuralNetworkLayer& layer = layers[layerIdx];
            PassColorsUp(layer);
            ProcessLayer(layer);
        }

        // stop once every blob has all colors, or once nothing changed
        if (AllShapesDone() || colorsBeforeRound == blobColors) {
            break;
        }
    }

    if (!useInputAndOutputConstraints) {
        return;
    }

    // ---- apply the declared output types ----
    const int outputCount = interface.output().size();
    for (int outputIdx = 0; outputIdx < outputCount; ++outputIdx) {
        const Specification::FeatureDescription& feature = interface.output(outputIdx);
        const std::string& blobName = feature.name();

        // skip over names that are just for the classifier
        if (blobName == interface.predictedprobabilitiesname()
            || blobName == interface.predictedfeaturename()) {
            continue;
        }

        // Outputs with no entry in blobShapes are skipped rather than
        // reported as an error.
        const auto located = blobShapes.find(blobName);
        if (located == blobShapes.end()) {
            continue;
        }
        ShapeConstraint& constraint = located->second;

        // TODO: add a catch with an error message that mentions the name

        const Specification::FeatureType& declared = feature.type();

        if (declared.Type_case() == Specification::FeatureType::kImageType) {
            // sequence constraint here is unbounded
            // batch is unbounded
            // other three read from the constraint as is -- later to be updated with flexibility
            const auto& image = declared.imagetype();
            if (image.colorspace() == Specification::ImageFeatureType_ColorSpace_GRAYSCALE) {
                constraint.setChannel(1);
            }
            else {
                constraint.setChannel(3);
            }
            constraint.setHeight(static_cast<size_t>(image.height()));
            constraint.setWidth(static_cast<size_t>(image.width()));
        }
        else { // assuming it's a multi array

            // allowing for the possibility that output shapes aren't constrained:
            // only rank-3 and rank-1 declarations are applied
            const auto& array = declared.multiarraytype();
            switch (array.shape_size()) {
                case 3:
                    constraint.setChannel(static_cast<size_t>(array.shape(0)));
                    constraint.setHeight(static_cast<size_t>(array.shape(1)));
                    constraint.setWidth(static_cast<size_t>(array.shape(2)));
                    break;
                case 1:
                    constraint.setChannel(static_cast<size_t>(array.shape(0)));
                    constraint.setHeight(1);
                    constraint.setWidth(1);
                    break;
                default:
                    break;
            }
        }
    }

}

// Convenience constructor: forwards the model's description and the layer
// list obtained through getNNSpec(model) to the main constructor.
// NOTE(review): the result of getNNSpec is dereferenced unconditionally --
// presumably callers only pass models that contain a neural network; confirm.
NeuralNetworkShaper::NeuralNetworkShaper(const Specification::Model& model, bool useInputAndOutputConstraints)
    : NeuralNetworkShaper(model.description(), *getNNSpec(model), useInputAndOutputConstraints)
{
}

// Returns the shape constraint stored for the blob called `name`.
// The lookup goes through blobShapes.at(), so an unknown name is reported by
// at() (std::out_of_range for the standard map types) instead of creating a
// new entry.
const ShapeConstraint& NeuralNetworkShaper::shape(const std::string& name) const {
    return blobShapes.at(name);
}

// Returns true when every constraint stored in blobShapes reports itself
// valid; returns false at the first one that does not.
bool NeuralNetworkShaper::isValid() const {
    // TODO: make this check all of the blob names, and make it compare to the input and output shapes
    for (const auto& blobEntry : blobShapes) {
        const bool constraintOk = blobEntry.second.isValid();
        if (!constraintOk) {
            return false;
        }
    }
    return true;
}

// Writes a "Network Shapes: " header to std::cout, followed by every stored
// shape constraint in blobShapes iteration order and two trailing newlines.
void NeuralNetworkShaper::print() const {

    std::cout << "Network Shapes: " << std::endl;
    for (const auto& blobEntry : blobShapes) {
        std::cout << blobEntry.second;
    }
    std::cout << std::endl << std::endl;

}




